Module 3, Sprint 2, Stroke Prediction Dataset¶

Introduction and Goals¶

The aim of this project is to design a prototype model and API that could be used by health care providers, doctors, and possibly even patients themselves to estimate the potential risk of having a stroke based on various demographic and lifestyle variables (this could be further integrated into a more complex system depending on the target user).

Our main goal is to build a system that would allow doctors to perform an initial screening that sorts their patients into risk groups, letting them focus on the patients who need attention (by providing additional testing, treatment, suggesting lifestyle changes, etc.). For example:

  • Low Risk (i.e. the probability of having a stroke predicted by our model is < 35%): can be mostly ignored.
  • Elevated Risk (35-60%): based on age and other attributes, lifestyle changes should be suggested.
  • High Risk (60-75%): additional monitoring and testing might be required.
  • Very High Risk (> 75%): extensive monitoring and testing (and/or treatment for specific health issues).

Core Assumptions:¶

Based on the problem we have defined, we'll rely on the following assumptions and goals when designing our model:

  • The cost of a false negative is significantly higher than the cost of a false positive

    • i.e. classifying a person who is likely to have a stroke as "Low Risk" would significantly decrease the usefulness of our model, which means we should prioritize Recall over Precision
  • The ratio of precision to recall still has to be maintained at a reasonably high level. Considering that the dataset is very imbalanced (stroke = 1 makes up only 5% of all samples), we believe that a ratio of around 2 to 3 false positives per true positive is reasonable as long as we can achieve a recall of 80-90%.

  • Having this in mind, we'll use a weighted F-score (FBeta with beta values ranging from 1.5 to 5) for hyperparameter tuning.

  • Our secondary objective is to maximize the accuracy of the predicted probabilities. In particular, we need to make sure that we do not misclassify any "high risk" patients as "low risk"; precision above the 75% probability threshold should also be relatively high (no more than about 2 false positives per true positive, because the cost of treatment and testing for a high-risk individual is likely to be very high).

Basically:

  • false positive: medium cost; additional tests and other services will be provided "unnecessarily" to individuals who have a low risk.
  • false negative: very high cost; a missed stroke would require immediate hospitalization and might result in death.

These are the thresholds above which we believe that our model could be useful in practice:

  • Recall: > 90%
  • Precision: > 25-30%
  • FBeta(b=2.5): > 40%

Additionally:

  • above a probability threshold of 0.75:
    • Recall: > 95%
    • Precision: > 50%
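As a rough sketch of how the FBeta objective described above can be computed and wrapped for hyperparameter tuning (the prediction counts below are made-up illustration data, not results from the dataset):

```python
import numpy as np
from sklearn.metrics import fbeta_score, make_scorer

# Hypothetical outcome: 10 true positives out of 12 actual positives,
# plus 20 false positives among 88 negatives
y_true = np.array([1] * 12 + [0] * 88)
y_pred = np.array([1] * 10 + [0] * 2 + [1] * 20 + [0] * 68)

# beta > 1 weighs recall more heavily than precision
score = fbeta_score(y_true, y_pred, beta=2.5, pos_label=1)

# The same metric wrapped as a scorer, usable by RandomizedSearchCV / GridSearchCV
fbeta_scorer = make_scorer(fbeta_score, beta=2.5, pos_label=1)
```

Here precision is 1/3 and recall 10/12, so the beta=2.5 score sits much closer to recall than to precision.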

Project Structure:¶

  • EDA:

    • Analysis of individual features and their relationships
    • Risk Factor Analysis
  • Model Tuning and Testing:

    • Hyper-parameter tuning with various different scoring functions to determine the optimal class weights and other parameters
    • Comparing different types of models and their performance. Factors used to select the best model for a production version of our prototype (in addition to raw performance):
      • Model complexity. Using a relatively simple model like SVM, logistic regression, etc. (or an ensemble model) would be preferable as long as we can achieve comparable performance. However, we do not believe that using a more complex "black box" model like XGBoost would be an issue in itself, because we can still use feature importance and generate SHAP plots for individual patients.
      • Computational efficiency. Using an efficient model (i.e. not Random Forest) would significantly decrease our iteration time.

EDA¶

1.1 Analysis of individual features and their distributions¶

The charts below show the distribution of all the features included in the dataset:

  1. Numerical features are displayed using KDE and Boxen plots, with additional testing for normality.
  2. Value counts are shown for non-numerical features.
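As an illustration of the normality-testing step, a minimal sketch using scipy's D'Agostino-Pearson omnibus test on synthetic BMI-like and glucose-like samples (the distribution parameters here are assumptions for illustration, not values from the dataset):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
normal_sample = rng.normal(loc=28, scale=6, size=1000)          # roughly BMI-like
skewed_sample = rng.lognormal(mean=4.6, sigma=0.4, size=1000)   # glucose-like, right-skewed

# D'Agostino-Pearson omnibus test: H0 = sample comes from a normal distribution
for name, sample in [("normal", normal_sample), ("skewed", skewed_sample)]:
    stat, p = stats.normaltest(sample)
    print(f"{name}: statistic={stat:.2f}, p-value={p:.4g}")
```

A small p-value rejects normality, which suggests preferring rank-based tests (e.g. Spearman, Kruskal-Wallis) later in the analysis.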

Most of the features are self-explanatory, but we should keep in mind these ranges for BMI and glucose levels:

BMI:¶

  • < 18.5 - underweight
  • 18.5 - 25 - normal
  • 25 - 30 - elevated
  • > 30 - obese

Glucose:¶

  • < 80 - low
  • 80 - 100 - normal
  • 100 - 125 - elevated
  • > 125 - high
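These ranges can be encoded with pandas, e.g. (a small sketch; the sample values are made up):

```python
import pandas as pd

def categorize(series: pd.Series, bins, labels) -> pd.Series:
    """Bucket a numeric column into the clinical ranges listed above."""
    return pd.cut(series, bins=bins, labels=labels, right=False)

bmi = pd.Series([17.0, 22.0, 27.5, 34.0])
glucose = pd.Series([70, 95, 110, 140])

bmi_cat = categorize(bmi, bins=[0, 18.5, 25, 30, float("inf")],
                     labels=["under", "normal", "elevated", "obese"])
glucose_cat = categorize(glucose, bins=[0, 80, 100, 125, float("inf")],
                         labels=["low", "normal", "elevated", "high"])
```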

1.2 Relationships Between Features¶

Because the data types of the features vary, we had to use different methods to measure the strength and significance of each pair:

- Chi-Squared Test: assesses independence between two categorical variables. Used for bool-bool pairs due to their categorical nature.

- Point-Biserial Correlation: measures correlation between a binary and a continuous variable. Used for bool-numerical pairs to account for mixed data types.

- Spearman's Rank Correlation: assesses the monotonic relationship between two continuous variables. Used for numerical-numerical pairs (suitable for non-normally distributed data).

Since the Chi-Squared test outputs an unbounded statistic which can't be directly compared to point-biserial or Spearman rank values, we have converted it to Cramér's V, which is normalized between 0 and 1. This was done to make the values in the matrix more uniform; however, we must note that Cramér's V and Spearman's correlation coefficients are fundamentally different statistics and generally can't be directly compared.
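A minimal sketch of the three measures on synthetic data (the variables and their relationships are invented for illustration):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n = 500
age = rng.uniform(20, 80, n)
hypertension = (age + rng.normal(0, 15, n) > 60).astype(int)  # binary, age-related
ever_married = (age + rng.normal(0, 10, n) > 30).astype(int)  # binary, age-related

# bool-bool: chi-squared on the contingency table, then Cramér's V to normalize
table = np.array([[np.sum((hypertension == i) & (ever_married == j))
                   for j in (0, 1)] for i in (0, 1)])
chi2, p, dof, _ = stats.chi2_contingency(table)
cramers_v = np.sqrt(chi2 / (n * (min(table.shape) - 1)))

# bool-numerical: point-biserial correlation
r_pb, p_pb = stats.pointbiserialr(hypertension, age)

# numerical-numerical: Spearman rank correlation
rho, p_s = stats.spearmanr(age, age + rng.normal(0, 5, n))
```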
[Figure: Weight and Age]

The charts below include all the pairs of numerical and non-numerical variables where the groups are significantly different.

i.e. plots are only rendered in cases where the null hypothesis "both groups originate from the same distribution" is rejected by the Kruskal–Wallis test (a non-parametric analogue of one-way ANOVA).
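For example, the Kruskal–Wallis test from scipy on three hypothetical glucose-by-smoking-status groups (synthetic data, not the actual dataset):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(1)
# Hypothetical glucose levels for three smoking-status groups
never = rng.normal(100, 20, 200)
former = rng.normal(105, 20, 200)
current = rng.normal(115, 20, 200)

# H0: all groups originate from the same distribution
stat, p = stats.kruskal(never, former, current)
if p < 0.05:
    print(f"groups differ (H={stat:.1f}, p={p:.2g}) -> render the plot")
```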


Risk Factor Analysis¶

In this part we'll look into the relationship between specific risk factors which we would assume to be significantly related to the likelihood of having a stroke (based on both correlation and subject knowledge):

  • age
  • hypertension
  • heart_disease
  • avg_glucose_level
  • bmi
  • smoking_status

The KDE plots show the likelihood of having a stroke at a specific age if the patient has any of the listed risk factors (the Y axis is relative to the full sample of individuals with the risk factor, not just people who have the condition and had a stroke).


We can see that the number of risk factors on average increases until the age of ~60 and afterwards starts slightly decreasing. This is likely a case of survivorship bias, as most of these factors tend to have a negative effect on life expectancy.


This chart shows individual KDE density curves for each subgroup based on age (it can be interpreted similarly to a histogram).

Interestingly, the difference is most prominent below ~65; afterwards the effect of having just 1 or 2 risk factors is much lower.

We can see that people who are not overweight, do not smoke, and do not have elevated glucose levels or heart issues have a much lower probability of having a stroke as long as they are younger than 60.


Generally, most of the risk factors besides heart disease seem to have a similar effect below the age of 60; afterwards, having diabetes or hypertension has a much higher effect.

1.5 PCA¶

We have attempted to use PCA to reduce the dimensionality of the dataset.

This might be necessary for datasets which include a very high number of features. Since this specific dataset is very simple and includes a very low number of columns, this was done only for informative/educational purposes.

Additionally, we have included binary/categorical variables, which is also generally not advisable in real-world cases.

While PCA can be used as a preprocessing step (and we have experimented with using it for simple logistic regression or SVM models), it is generally not necessary for simple datasets like this.

Total feature count: 10

PCA was done using a scikit-learn pipeline which handles standardization for the numerical variables.

We can see that the dataset (not including the target variable) could effectively be reduced to 8 components (which preserve about 80% of the variance). Since this isn't much lower than the total number of variables, it's not particularly useful for ML or even visualization purposes.
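A sketch of the standardize-then-PCA pipeline on random stand-in data (the real pipeline also encodes categorical columns, which is omitted here):

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 10))  # stand-in for the 10 encoded features

pipe = Pipeline([("scale", StandardScaler()), ("pca", PCA())])
pipe.fit(X)

# Cumulative explained variance -> number of components needed for 80%
cum_var = np.cumsum(pipe.named_steps["pca"].explained_variance_ratio_)
n_components_80 = int(np.searchsorted(cum_var, 0.80) + 1)
```

On uncorrelated random data like this, ~8 of 10 components are needed for 80% of the variance, mirroring the limited payoff we observed.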

2. ML Models¶

We have used various different models. Our process included these steps:

  1. Define separate configurations for each model based on the target variables/metrics used for tuning (see src/model_config.py and shared/ml_config_core.py). We have tested these models:
  • XGBoost
  • CatBoost
  • LGBM
  • SVM
  • Random Forest
  • Custom ensemble model (log + SVM + KNN with a soft voting classifier)

Training and validation were performed using stratified K-folds (5 folds).
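A minimal sketch of the stratified 5-fold setup, on a synthetic target with the same ~5% positive rate as this dataset:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(0)
y = np.array([1] * 250 + [0] * 4750)   # ~5% positive class, like stroke = 1
X = rng.normal(size=(len(y), 3))       # placeholder features

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, test_idx in skf.split(X, y):
    # stratification keeps the ~5% positive ratio in every fold
    print(f"train positives: {y[train_idx].mean():.3f}, "
          f"test positives: {y[test_idx].mean():.3f}")
```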

  2. Hyperparameter tuning was performed for each model. Because the dataset is heavily imbalanced, we have used various different target metrics:
  • macro F1
  • recall (only target class)
  • F1 (only target class)
  • Various FBeta scores (beta between 1.5 and 5)

Built-in class-weight parameters were used for all the models besides the ensemble one, which uses SMOTE, ADASYN, standard oversampling, etc. The results for each individual model are stored separately in the .tuning_results folder.
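For the built-in class-weight parameters, a common heuristic starting point (e.g. for XGBoost's scale_pos_weight) is the negative-to-positive ratio; a quick sketch with this dataset's approximate class balance:

```python
import numpy as np

y = np.array([1] * 250 + [0] * 4750)   # ~5% stroke = 1, as in this dataset

# Heuristic starting point: n_negative / n_positive; the values in our
# search grid (1..40) bracket this ratio
n_pos, n_neg = int((y == 1).sum()), int((y == 0).sum())
scale_pos_weight = n_neg / n_pos
print(scale_pos_weight)  # 19.0
```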

Testing and Validation¶

A standard train-validation-test split is not used in any part of the analysis. In the initial iterations we used only stratified K-fold cross-validation for tuning and validation, plus a 20% split for measuring and visualizing model performance and for implementing the final risk-category classification chart. However, because the dataset is very small and imbalanced (there are only ~200 rows with stroke = 1, so a test split would include only 40 positive samples or so), the variance for most metrics would be very high, which made comparing performance between models and configurations somewhat tricky.

Instead, we decided not to use a train-validation-test split and to use the combined, averaged results from each of the 5 CV folds as a synthetic "test" sample:

  1. A scikit-learn pipeline with 5 folds is used to measure the performance (precision, accuracy, recall, F1, FBeta, etc.) for each fold.
  2. Additionally, the probabilities, predictions, and test X, y values from each of the folds are combined into a single dataset:
    • Basically, our "test" split includes all the samples from the dataset, with each test fold predicted by a model fitted on the remaining folds.
    • We have found the performance of this approach broadly comparable to using a train-test split overall, but it has allowed us to significantly reduce the variance between individual models, so splitting the dataset seemed redundant.
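This out-of-fold scheme is essentially what scikit-learn's cross_val_predict does; a sketch on synthetic imbalanced data (the model and data here are placeholders, not our actual configurations):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.metrics import recall_score

X, y = make_classification(n_samples=1000, weights=[0.95], random_state=0)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Every sample gets an out-of-fold probability from a model that never saw it,
# so the combined predictions act as a synthetic full-size "test" set
proba = cross_val_predict(
    LogisticRegression(class_weight="balanced", max_iter=1000),
    X, y, cv=cv, method="predict_proba",
)[:, 1]
preds = (proba >= 0.5).astype(int)
print(f"out-of-fold recall: {recall_score(y, preds):.2f}")
```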

Potential issues and limitations:¶

  • Using hyperparameter tuning with the same CV folds might result in indirect data leakage and overfitting.
  • To alleviate these potential issues we should consider approaches like nested CV (i.e. one loop for tuning and one for assessing performance).
  • Bootstrapping could be used instead to complement or partially replace our approach (i.e. use CV for tuning and bootstrapping for evaluation).
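A nested-CV sketch on synthetic data (inner loop tunes, outer loop evaluates; the model and grid here are placeholders, not our actual configs):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold, cross_val_score

X, y = make_classification(n_samples=600, weights=[0.9], random_state=0)

# Inner loop tunes C; outer loop measures performance on folds the tuner never saw
inner = StratifiedKFold(n_splits=3, shuffle=True, random_state=1)
outer = StratifiedKFold(n_splits=3, shuffle=True, random_state=2)

search = GridSearchCV(LogisticRegression(max_iter=1000),
                      {"C": [0.1, 1.0, 10.0]}, scoring="f1", cv=inner)
scores = cross_val_score(search, X, y, cv=outer, scoring="f1")
print(f"nested CV f1: {scores.mean():.2f} +/- {scores.std():.2f}")
```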
Best parameters found by randomized search (n_iter=250) for the two XGBoost configurations shown:

  • XGBoostCatF1UndersampleAuto (best score 0.191; undersampling, F1 on the target class): scale_pos_weight=1, n_estimators=250, min_child_weight=1.5, max_depth=6, learning_rate=0.01, gamma=0.3
  • XGBoostTuneCatFBeta_25 (best score 0.433; FBeta(2.5) scoring): scale_pos_weight=25, n_estimators=250, min_child_weight=1.5, max_depth=4, learning_rate=0.01, gamma=0.1
Approximate training/tuning time per configuration:

  • SVC_SMOTE: 20.7 s
  • LGBMForestBaseConfigTuneFBeta_25: 3.2 s
  • XGBoostCatF1UndersampleAuto: 0.5 s
  • Ensemble_Log_KNN_SVM_SMOTE: 30.4 s
  • XGBoostTuneCatFBeta_25: 0.6 s
  • XGBoostTuneCatFBeta_325: 0.6 s
  • XGBoostTuneCatFBeta_40: 0.6 s
  • XGBoostTuneCatFBeta_50: 0.6 s
  • XGBoostTuneRecall: 0.5 s
  • CatBoostBaseConfigTuneFBeta_15: 0.4 s
  • CatBoostBaseConfigTuneFBeta_20: 0.4 s
  • CatBoostBaseConfigTuneFBeta_25: 1.6 s
  • CatBoostBaseConfigTuneFBeta_325: 0.4 s
  • CatBoostBaseConfigTuneFBeta_40: 0.4 s
  • CatBoostBaseConfigTuneRecall: 0.5 s
Results¶

The table below shows the results for each configuration using the optimal parameters:

model_key accuracy precision_macro recall_macro f1_macro target_f1 target_recall target_precision fbeta_1.5 fbeta_2.5 fbeta_4.0 n_samples
XGBoostTuneCatFBeta_25 0.729 0.550 0.758 0.518 0.199 0.789 0.114 0.279 0.434 0.585 4908.0
XGBoostCatF1UndersampleAuto 0.712 0.548 0.754 0.508 0.191 0.799 0.109 0.270 0.426 0.582 4908.0
XGBoostTuneRecall 0.717 0.542 0.715 0.503 0.177 0.713 0.101 0.249 0.388 0.525 4908.0
SVC_SMOTE 0.825 0.541 0.641 0.539 0.176 0.440 0.110 0.229 0.312 0.374 4908.0
Ensemble_Log_KNN_SVM_SMOTE 0.844 0.544 0.635 0.548 0.182 0.407 0.117 0.231 0.303 0.355 4908.0
XGBoostTuneCatFBeta_325 0.897 0.561 0.619 0.576 0.207 0.316 0.153 0.238 0.276 0.297 4908.0
XGBoostTuneCatFBeta_40 0.897 0.561 0.619 0.576 0.207 0.316 0.153 0.238 0.276 0.297 4908.0
XGBoostTuneCatFBeta_50 0.897 0.561 0.619 0.576 0.207 0.316 0.153 0.238 0.276 0.297 4908.0
CatBoostBaseConfigTuneFBeta_25 0.707 0.518 0.593 0.472 0.120 0.469 0.069 0.168 0.260 0.349 4908.0
CatBoostBaseConfigTuneFBeta_15 0.710 0.518 0.592 0.473 0.120 0.464 0.069 0.168 0.259 0.347 4908.0
CatBoostBaseConfigTuneFBeta_20 0.710 0.518 0.592 0.473 0.120 0.464 0.069 0.168 0.259 0.347 4908.0
LGBMForestBaseConfigTuneFBeta_25 0.330 0.511 0.557 0.281 0.093 0.804 0.049 0.141 0.258 0.423 4908.0
CatBoostBaseConfigTuneFBeta_325 0.360 0.510 0.554 0.299 0.092 0.766 0.049 0.140 0.254 0.412 4908.0
CatBoostBaseConfigTuneFBeta_40 0.360 0.510 0.554 0.299 0.092 0.766 0.049 0.140 0.254 0.412 4908.0
CatBoostBaseConfigTuneRecall 0.360 0.510 0.554 0.299 0.092 0.766 0.049 0.140 0.254 0.412 4908.0

The behaviour of the precision-recall curves for all models indicates very poor performance (precision is very low at all thresholds). Additionally, the curves are all:

  • non-monotonic, i.e. they change direction on the Y axis several times as the threshold is changed, due to fluctuating true- and false-positive counts.
  • quick to lose precision: precision drops even at very low thresholds and varies significantly due to the models' inability to consistently identify the sparse positive cases in the heavily imbalanced dataset.
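The non-monotonic behaviour can be reproduced with a weak synthetic scorer (made-up data; any classifier with little separation on a ~5% positive class behaves this way):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

rng = np.random.default_rng(3)
y_true = rng.binomial(1, 0.05, 2000)            # ~5% positives
# Weak scores: positives only slightly shifted, mimicking poor separation
scores = rng.normal(0, 1, 2000) + 0.8 * y_true

precision, recall, thresholds = precision_recall_curve(y_true, scores)

# Precision is not monotonic in the threshold: it dips whenever the next
# highest-scored sample turns out to be a false positive
dips = int(np.sum(np.diff(precision) < 0))
print(f"precision decreases {dips} times along the curve")
```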

Selecting the "Best" Model¶

We have been able to get relatively comparable results with all the complex boosting models, and our ensemble model performs similarly as long as some oversampling technique like SMOTE is used. With additional tuning it might provide effectively the same performance as XGBoost or CatBoost. However, training the ensemble (LogisticRegression + KNeighborsClassifier + SVC) is very slow, so it would still be much more practical to use a complex model which handles balancing directly.

As far as performance and hyperparameter tuning go, the only parameter that really matters is the class weight, which directly affects the recall/precision ratio (based on our selected FBeta value for scoring).

With that in mind, we have selected XGBoostTuneCatFBeta_25 as our production model. While its overall performance is not ideal, it still provides reasonable performance relative to the assumptions outlined previously. XX% recall relative to XX% precision means that for every person with stroke = 1 we will also select ~{N} individuals as "high risk".

Model Feature Importance and SHAP plots¶

We'll use SHAP values to further analyze the importance of each feature:


Probability Thresholds¶

An approach that might mitigate the precision/recall issue is to further split the risk group identified by our model into separate "Low", "Medium", and "High" risk categories, which would allow us to use resources more effectively by giving more focus to the individuals who have the highest risk:
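A sketch of bucketing predicted probabilities into the risk groups defined in the introduction (the probabilities here are random placeholders, not model output):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(7)
proba = rng.uniform(0, 1, 1000)   # stand-in for predicted stroke probabilities

# Bucket predicted probabilities into the risk groups from the introduction
bins = [0, 0.35, 0.60, 0.75, 1.0]
labels = ["Low", "Elevated", "High", "Very High"]
risk_group = pd.Series(pd.cut(proba, bins=bins, labels=labels, include_lowest=True))
print(risk_group.value_counts())
```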


The chart shows the performance of the model if only individuals with stroke probability > T are selected. Additionally, the overlay indicates the number of people whose predicted probability falls in a given range. The overlays can be used to select the most at-risk individuals based on their predicted probability.

Conclusion¶

  • We have tried multiple different ML models to predict the stroke column.
  • While our selected model achieves high recall (~80% of stroke = 1 cases identified), its precision is very low and falls short of our initial targets.
    • This is a significant limitation, because many low-risk individuals would be flagged for unnecessary follow-up.
    • On the positive side, the high recall means very few at-risk patients are missed, and the probability-based risk groups let us concentrate additional testing and monitoring on the individuals with the highest predicted risk.

Limitations and Suggestions for Future Improvements:¶

Business Case/Interpretation¶
  • A deeper cost-based analysis should be performed (ideally based on data from specific insurance companies, government healthcare systems, etc.) to determine the acceptable precision/recall ratio. While the direct and indirect costs of an individual suffering a stroke might be high:
    • It's not clear what real benefit identifying at-risk individuals provides. If stroke risk is mostly driven by lifestyle choices, additional treatment and monitoring would not be particularly useful if the patients are unwilling to alter their lifestyles.
    • Potentially this model could be used in a consumer-facing app for self-assessment purposes (i.e. to encourage lifestyle changes).
Technical¶
  • Tuning for 'log_loss' instead of a classification metric.
  • Tweaking the decision threshold during hyper-parameter tuning might be beneficial.
  • Using AUCPR for tuning.
  • Hyper-parameters that guard against over-fitting, like 'early_stopping_rounds', can be utilized to cut model training early. {TODO}